import numpy as np
import lazy_property
from scipy.stats import norm
from functools import reduce
from operator import add

from .data_helper import cols_by_date, list_dates


class Constraint:
    """Inclusive box constraint on a parameter vector.

    Parameters
    ----------
    list_bounds : sequence of ``(lower, upper)`` pairs
        One pair per parameter; converted with ``np.array`` so that column 0
        holds the lower bounds and column 1 the upper bounds.
    """

    def __init__(self, list_bounds):
        self._bounds = np.array(list_bounds)
        self.l_bound = self._bounds[:, 0]
        self.u_bound = self._bounds[:, 1]

    def is_satisfied(self, x, options=None):
        """Return True iff ``l_bound <= x <= u_bound`` holds element-wise.

        ``options`` is accepted only for call-site compatibility and is not
        used (nor modified). NaN entries make the constraint fail.
        """
        x = np.asarray(x)
        return bool(np.all((x >= self.l_bound) & (x <= self.u_bound)))


def stable_norm(x, order=2):
    """Return ``||x||_order / len(x)``, computed after rescaling by ``max|x|``.

    The absolute values are divided by their maximum before taking the norm,
    so every rescaled entry lies in ``[0, 1]`` and ``|x_i| ** order`` cannot
    overflow even for large ``order``. The scale is then multiplied back in.

    Parameters
    ----------
    x : 1-D array-like
        Must be non-empty (``np.max`` raises ``ValueError`` otherwise).
    order : int or float
        Norm order forwarded to ``np.linalg.norm``.

    Returns
    -------
    float
        ``0.0`` when every entry of ``x`` is zero.
    """
    x = np.abs(np.asarray(x))
    m = np.max(x)
    if m == 0:
        # All-zero vector: its norm is 0. Rescaling would compute 0/0 -> nan.
        return 0.0
    return (m / len(x)) * np.linalg.norm(x / m, order)


class Likelihood:
    """Log-likelihood of per-event payment outcomes over a grid of latent types.

    Each row of ``df_status`` is split into one group of raw columns per entry
    of ``event_dates`` (see ``cols_by_date``). For a parameter vector the class:

    1. builds a regressor vector per event (``event_to_vect``/``score_events``)
       and takes its dot product with the coefficients;
    2. adds one shift per latent type (normal quantiles scaled by the last
       parameter);
    3. maps the shifted score through an s-shaped link to an intensity, and
       the intensity to per-event probabilities (``prob_at_events_by_type``);
    4. takes the per-type geometric mean over events, aggregates across types
       with ``stable_norm``, logs it, and averages over rows.

    Parameter layout: regression coefficients first (in the order produced by
    ``score_events``), then ``s_shape_low``, ``s_shape_high``, ``type_sdev``
    as the last three entries.

    The flags ``timetrend``, ``valor_regime``, ``interaction_with_G1_deadline``,
    ``multicov_shareby3``, ``multicov_shareby3_age``,
    ``multicov_shareby3_age_missing``, ``pre_post_july``, ``pre_post_june`` and
    ``with_calls_data`` select alternative specifications. Only ``timetrend``
    and ``valor_regime`` may be combined; the rest are mutually exclusive.
    """

    PARAM_NAMES = (
        'some_payment',
        'payment',
        'priority_G1', 'priority_G2', 'priority_G3',
        'action_valor', 'action_rec1', 'action_medida',
        'covariate', 'G1_rec1', 'G1_medida',
        's_shape_low', 's_shape_high',
        'type_sdev'
    )
    # The time-trend / late-valor coefficients are appended after 'G1_medida'
    # (index 11), matching score_events, which appends those regressors last.
    PARAM_NAMES_TT = PARAM_NAMES[0:11] + ('time', ) + PARAM_NAMES[11:]
    PARAM_NAMES_VR = PARAM_NAMES[0:11] + ('action_valor_late', ) + PARAM_NAMES[11:]
    PARAM_NAMES_TT_VR = PARAM_NAMES[0:11] + ('time', 'action_valor_late',) + PARAM_NAMES[11:]
    # One event per entry of list_dates, except the last one.
    event_dates = list(list_dates[:-1])

    def __init__(self, df_status, constraint, num_types=10,
                 covariate='prob_repayment_exo_covariates',
                 Q1_only=False, options=None,
                 max_payment_threshold=None, timetrend=False,
                 valor_regime=False,
                 interaction_with_G1_deadline=False, multicov_shareby3=False,
                 pre_post_july=False, pre_post_june=False,
                 multicov_shareby3_age=False, multicov_shareby3_age_missing=False, with_calls_data=False):
        """Select the columns of ``df_status`` and configure the specification.

        Parameters
        ----------
        df_status : DataFrame
            Must contain every column in ``all_cols`` plus ``'total_due'``.
            The payment columns are divided row-wise by ``total_due``.
        constraint : Constraint
            Box bounds checked against the untransformed parameter vector.
        num_types : int
            ``num_types - 1`` standard-normal quantiles form the type grid.
        covariate : str or list of str
            Covariate column(s) appended to every event's column group.
        Q1_only : bool
            In this class it only changes the default
            ``max_payment_threshold`` (1 instead of 4).
        options : dict, optional
            ``'sshape'`` ('linear' or 'logistic') and ``'param_transform'``
            (callable applied to the parameters). See ``set_default_options``.
        max_payment_threshold : float, optional
            Once the share paid reaches this value, the event's intensity
            drops to the 1e-3 floor.
        timetrend, valor_regime, interaction_with_G1_deadline,
        multicov_shareby3, multicov_shareby3_age,
        multicov_shareby3_age_missing, pre_post_july, pre_post_june,
        with_calls_data : bool
            Specification flags (see the class docstring).

        Raises
        ------
        ValueError
            If an unsupported combination of specification flags is set.
        """
        self._covariate = covariate
        self._suggested_options = options
        self.timetrend = timetrend
        self.valor_regime = valor_regime
        self.Q1_only = Q1_only
        self.interaction_with_G1_deadline = interaction_with_G1_deadline
        self.multicov_shareby3 = multicov_shareby3
        self.multicov_shareby3_age = multicov_shareby3_age
        self.multicov_shareby3_age_missing = multicov_shareby3_age_missing
        self.pre_post_july = pre_post_july
        self.pre_post_june = pre_post_june
        self.with_calls_data = with_calls_data
        # Fail fast, before doing any work on the data.
        self._check_specification()
        # Explicit copy: the in-place rescaling below must neither write
        # through to, nor trigger chained-assignment warnings about, the
        # caller's frame.
        self.df_status = df_status[self.all_cols].copy()
        # Express every payment as a share of the total amount due.
        self.df_status.loc[:, self.payment_cols] = \
            self.df_status[self.payment_cols].divide(
                df_status['total_due'], axis=0)
        self.constraint = constraint
        self._num_types = num_types
        if max_payment_threshold is None:
            max_payment_threshold = 1 if self.Q1_only else 4
        self.max_payment_threshold = max_payment_threshold
        self.PARAM_NAMES = self._select_param_names()

    def _check_specification(self):
        """Raise ValueError unless at most one specification variant is active.

        ``timetrend`` and ``valor_regime`` count together as one variant. Each
        variant assumes its own event-column layout and parameter count. When
        several are set, ``__init__``, ``event_to_vect``, ``score_events`` and
        the row-slicing helpers would pick different variants and disagree.
        """
        variants = (
            self.timetrend or self.valor_regime,
            self.interaction_with_G1_deadline,
            self.multicov_shareby3,
            self.multicov_shareby3_age,
            self.multicov_shareby3_age_missing,
            self.pre_post_july,
            self.pre_post_june,
            self.with_calls_data,
        )
        if sum(bool(v) for v in variants) > 1:
            raise ValueError('Cannot yet handle timetrend, early/late valor regimes, and interactions simultaneously')

    def _select_param_names(self):
        """Return the PARAM_NAMES tuple for the active specification."""
        base = Likelihood.PARAM_NAMES
        if self.timetrend and self.valor_regime:
            return Likelihood.PARAM_NAMES_TT_VR
        if self.timetrend:
            return Likelihood.PARAM_NAMES_TT
        if self.valor_regime:
            return Likelihood.PARAM_NAMES_VR
        if self.multicov_shareby3:
            inserted = ("share_repaid_by_3",)
        elif self.multicov_shareby3_age:
            inserted = ("share_repaid_by_3", "age",)
        elif self.multicov_shareby3_age_missing:
            inserted = ("share_repaid_by_3", "age", "age_missing", "last_year_share_repaid_missing")
        elif self.interaction_with_G1_deadline:
            # NOTE(review): event_to_vect emits, from index 9 on,
            # [cov2*G1, G1*rec1, G1*medida, cov2*rec1, cov2*medida]. These
            # names put G1_rec1/G1_medida at 12-13, so indices 10-13 look
            # mislabeled. Confirm before using PARAM_NAMES for reporting.
            inserted = ("above_med_G1_deadline", "above_med_G1_deadline_rec1", "above_med_G1_deadline_medida")
        elif self.pre_post_july:
            # NOTE(review): score_events appends the post-July regressors
            # after G1_rec1/G1_medida (coefficients 11-13), but the names are
            # inserted at 9-11. The labels look misaligned; confirm.
            inserted = ("post_july_G1", "post_july_rec1", "post_july_G1_rec1")
        elif self.pre_post_june:
            # NOTE(review): same ordering concern as pre_post_july.
            inserted = ("post_june_G1", "post_june_rec1", "post_june_G1_rec1")
        elif self.with_calls_data:
            inserted = ("number_of_calls",)
        else:
            return base
        return base[0:9] + inserted + base[9:]

    @lazy_property.LazyProperty
    def options(self):
        """Options with defaults filled in (computed on first access)."""
        return self.set_default_options(self._suggested_options)

    @staticmethod
    def set_default_options(options):
        """Return a copy of ``options`` with missing defaults filled in.

        Defaults: ``'sshape'`` -> ``'linear'`` and ``'param_transform'`` ->
        identity. The caller's dict is left untouched.
        """
        options = {} if options is None else dict(options)
        options.setdefault('sshape', 'linear')
        options.setdefault('param_transform', lambda x: x)
        return options

    @lazy_property.LazyProperty
    def num_periods(self):
        """Number of events per row (one per entry of ``event_dates``)."""
        return len(self.event_dates)

    @lazy_property.LazyProperty
    def num_params(self):
        """Length of the parameter vector for the active specification."""
        return len(self.PARAM_NAMES)

    @lazy_property.LazyProperty
    def normalized_types(self):
        """Standard-normal quantiles at ``i / num_types`` for ``i = 1..num_types-1``."""
        return norm.ppf([i / self._num_types
                         for i in range(1, self._num_types)])

    @lazy_property.LazyProperty
    def cols_by_date(self):
        """Per event date: that date's columns followed by the covariate column(s)."""
        covariates = (self._covariate if isinstance(self._covariate, list)
                      else [self._covariate])
        # ``cols_by_date`` in this body resolves to the module-level helper
        # imported from data_helper, not to this property: class attributes
        # are not in scope inside method bodies.
        return [cols_by_date(d, with_calls_data=self.with_calls_data) + covariates
                for d in self.event_dates]

    @lazy_property.LazyProperty
    def all_cols(self):
        """All per-date column groups concatenated, in event order (may repeat names)."""
        return reduce(add, self.cols_by_date)

    @lazy_property.LazyProperty
    def payment_cols(self):
        """The second column of each date's group (the payment column)."""
        return [c[1] for c in self.cols_by_date]

    def _event_width(self):
        """Number of raw columns that make up one event in a ``df_status`` row."""
        if self.interaction_with_G1_deadline or self.multicov_shareby3:
            return 6
        if self.with_calls_data:
            return 6
        if self.multicov_shareby3_age:
            return 7
        if self.multicov_shareby3_age_missing:
            return 9
        return 5

    def event_to_vect(self, event):
        """Turn one event's raw columns into its regressor vector.

        Layout: ``[some_payment, payment, G1, G2, G3, valor, rec1, medida,
        <covariates>, G1*rec1, G1*medida]``. With
        ``interaction_with_G1_deadline``, the second covariate enters as
        ``cov2*G1``, and ``cov2*rec1``, ``cov2*medida`` are appended at the end.
        """
        if self.interaction_with_G1_deadline or self.multicov_shareby3:
            _, payment, priority, action, cov1, cov2 = event
            covariates = [cov1, cov2]
        elif self.with_calls_data:
            _, payment, priority, action, calls, cov1 = event
            covariates = [cov1, calls]
        elif self.multicov_shareby3_age:
            _, payment, priority, action, cov1, cov2, cov3 = event
            covariates = [cov1, cov2, cov3]
        elif self.multicov_shareby3_age_missing:
            _, payment, priority, action, cov1, cov2, cov3, cov4, cov5 = event
            covariates = [cov1, cov2, cov3, cov4, cov5]
        else:
            _, payment, priority, action, cov = event
            covariates = [cov]
        some_payment = payment > 0
        priority_dummy = [priority == p for p in ['G1', 'G2', 'G3']]
        action_dummy = [action == a for a in ['valor', 'rec1', 'medida']]
        is_g1 = priority_dummy[0]
        interactions = [is_g1 * action_dummy[1], is_g1 * action_dummy[2]]
        trailing = []
        if self.interaction_with_G1_deadline:
            cov1, cov2 = covariates
            covariates = [cov1, cov2 * is_g1]
            trailing = [cov2 * action_dummy[1], cov2 * action_dummy[2]]
        return np.array([some_payment, payment] + priority_dummy + action_dummy
                        + covariates + interactions + trailing)

    def uniform_param_generator(self, size=1):
        """Draw ``size`` parameter vectors uniformly inside the constraint box.

        Returns an array of shape ``(size, num_params)``. In each draw the two
        s-shape bounds (columns -3 and -2) are sorted so that low <= high.
        """
        params = np.random.uniform(
            low=self.constraint.l_bound,
            high=self.constraint.u_bound,
            size=(size, self.num_params))
        params[:, [-3, -2]] = np.sort(params[:, [-3, -2]], axis=1)
        return params

    def scaled_types(self, params):
        """Latent-type shifts: ``normalized_types`` scaled by ``params[-1]``."""
        return params[-1] * self.normalized_types

    def s_shape(self, params):
        """Return the link function parameterized by ``a, b = params[-3:-1]``.

        ``options['sshape']`` selects the form:

        * ``'linear'``:   ``min(b - a, max(x - a, 1e-3))``
        * ``'logistic'``: ``b / (1 + exp(a - x))``

        Raises
        ------
        ValueError
            For any other ``'sshape'`` value.
        """
        a, b = params[-3:-1]
        kind = self.options['sshape']
        if kind == 'linear':
            def _s_shape(x):
                return np.minimum(b - a, np.maximum(x - a, 1e-3))
        elif kind == 'logistic':
            def _s_shape(x):
                return np.divide(b, 1 + np.exp(-x + a))
        else:
            raise ValueError(f"Unknown 'sshape' option: {kind!r}")
        return _s_shape

    def log_likelihood_at_param(self, params):
        """Mean per-row log-likelihood at ``params``.

        Returns ``-inf`` when ``params`` violates the constraint. Otherwise
        ``options['param_transform']`` is applied before evaluation.

        Raises
        ------
        ValueError
            If ``len(params)`` differs from ``num_params``.
        """
        if len(params) != self.num_params:
            raise ValueError("Not the right number of parameters")
        if not self.constraint.is_satisfied(params, self.options):
            return -np.inf
        params = self.options['param_transform'](
            np.array(params))
        params = self.expand_params(params)
        ind_log_lik = self.individual_log_likelihood(params)
        return np.mean(self.df_status.apply(ind_log_lik, axis=1))

    def expand_params(self, params):
        """Bundle ``(params, s_shape link, type shifts as a column vector)``."""
        s_shape = self.s_shape(params)
        scaled_theta = np.array(self.scaled_types(params)).reshape(-1, 1)
        return (params, s_shape, scaled_theta)

    def individual_log_likelihood(self, params):
        """Return a ``row -> log-likelihood`` callable for ``DataFrame.apply``."""
        def _individual_log_likelihood(row):
            return self.log_norm_prob_across_thetas(row, params)
        return _individual_log_likelihood

    def log_norm_prob_across_thetas(self, row, params):
        """Log of ``stable_norm`` (order ``num_periods``) of the per-type probabilities."""
        prob_by_thetas = self.geo_mean_prob_across_events(row, params)
        return np.log(stable_norm(prob_by_thetas, self.num_periods))

    def geo_mean_prob_across_events(self, row, params):
        """Per-type geometric mean of the event probabilities (mean over axis 1)."""
        probs = self.prob_at_events_by_type(row, params)
        return np.exp(np.mean(np.log(probs), axis=1))

    def prob_at_events_by_type(self, row, params):
        """Per-type, per-event probability.

        It is ``1 - exp(-intensity)`` where ``fw_payment_made`` is 1 and
        ``exp(-intensity)`` where it is 0.
        """
        fw_payment = self.fw_payment_made(row)
        intensity = self.intensity_at_events_by_type(row, params)
        return fw_payment + np.multiply(1 - 2 * fw_payment, np.exp(-intensity))

    def intensity_at_events_by_type(self, row, params):
        """S-shaped score, zeroed once the share paid reaches the threshold, floored at 1e-3."""
        _, s_shape, _ = params
        not_fully_paid = self.share_payment_made(row) < self.max_payment_threshold
        return np.maximum(
            1e-3, np.multiply(
                s_shape(self.score_events_by_type(row, params)), not_fully_paid)
        )

    def score_events_by_type(self, row, params):
        """Event scores shifted by each type: shape ``(num_types - 1, num_periods)``."""
        _, _, scaled_thetas = params
        this_scores = np.array(self.score_events(row, params))
        return scaled_thetas + this_scores

    def score_events(self, row, params):
        """Linear score of every event: regressors dotted with ``params[0][:-3]``.

        Specification-dependent regressors are appended after the
        ``event_to_vect`` columns, in this order: time index, late-valor dummy,
        then the post-July / post-June interactions. Periods are numbered from 1.
        """
        vectorized_row = row.values.reshape(-1, self._event_width())
        v_event = np.apply_along_axis(self.event_to_vect, 1, vectorized_row)
        periods = np.arange(1, self.num_periods + 1)
        extra_cols = []
        if self.timetrend:
            extra_cols.append(periods)
        if self.valor_regime:
            # action_valor (column 5) from period 8 onward.
            extra_cols.append((periods >= 8) * v_event[:, 5])
        if self.pre_post_july:
            # G1 (col 2), rec1 (col 6) and G1*rec1 (col 9) from period 13 onward.
            post = periods >= 13
            extra_cols.extend([post * v_event[:, 2],
                               post * v_event[:, 6],
                               post * v_event[:, 9]])
        if self.pre_post_june:
            # Same three regressors from period 8 onward.
            post = periods >= 8
            extra_cols.extend([post * v_event[:, 2],
                               post * v_event[:, 6],
                               post * v_event[:, 9]])
        if extra_cols:
            v_event = np.column_stack([v_event] + extra_cols)
        return np.dot(v_event, params[0][:-3])

    def fw_payment_made(self, row):
        """1 where the first column of each event group is positive, else 0."""
        return 1 * (row.values[::self._event_width()] > 0)

    def share_payment_made(self, row):
        """The second column of each event group (payment divided by total_due in __init__)."""
        return row.values[1::self._event_width()]
